Load Packages
library(MMCTime)
#> Warning: replacing previous import 'ape::where' by 'dplyr::where' when loading
#> 'MMCTime'
#> Warning: replacing previous import 'ape::rotate' by 'ggtree::rotate' when
#> loading 'MMCTime'
library(ggplot2)
library(ape)
library(posterior)
#> This is posterior version 1.5.0
#>
#> Attaching package: 'posterior'
#> The following objects are masked from 'package:stats':
#>
#> mad, sd, var
#> The following objects are masked from 'package:base':
#>
#> %in%, match
library(reshape2)
library(patchwork)
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following object is masked from 'package:ape':
#>
#> where
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
library(phangorn)
Ingest Data
# Simulation design: number of phi scenarios, number of mutation-clock
# scenarios, and number of runs per (phi, clock) combination.
phi_sets <- 4L
mut_sets <- 3L
set_size <- 48L

# Number of tips per tree (used below to index the q_* columns and tip depths).
n_tips <- 200

# Ground-truth parameter values: nu is shared by all runs, alphas holds one
# value per run index, phis one per phi scenario, mus/omegas one per clock.
nu <- 0.1
alphas <- seq(0.01, 0.75, length.out = set_size)
phis <- c(0, 0.5, 0.75, 1)
mus <- c(1.5, 3, 6)
omegas <- c(0.5, 1, 2)

# Name of the summary variable treated as the root height (presumably the time
# of node n_tips + 1 -- confirm against the fitting code), and the labels used
# for the relative height / relative length rows derived from it.
root_n <- "t_201"
height_n <- "rel_height"
len_n <- "rel_length"
# Load the posterior summaries of one analysis and rewrite its root-height and
# tree-length rows as relative errors against the ground-truth tree.
#
# clock:   clock scenario index; selects the "./mut<clock>" directory.
# phi_idx: phi scenario index.
# run_idx: run index within the phi scenario.
#
# Returns a data.frame with columns variable, median, q5, q95, rhat, ess_bulk,
# run_idx, phi_idx and clock. The row whose variable is `root_n` is renamed to
# `height_n` and the "tree_length" row to `len_n`; for both, q5/median/q95 are
# replaced by (estimate - truth) / truth.
f <- function(clock, phi_idx, run_idx) {
  # Result and ground-truth files are numbered across phi scenarios.
  idx <- (phi_idx - 1L) * set_size + run_idx

  res <- readRDS(paste0("./mut", clock, "/ana_out/res_", idx, ".rds"))
  kept_cols <- c("variable", "median", "q5", "q95", "rhat", "ess_bulk")
  tmp <- as.data.frame(res$summaries[, kept_cols])
  tmp$run_idx <- run_idx
  tmp$phi_idx <- phi_idx
  tmp$clock <- clock

  # Ground truth: total branch length and height (largest root-to-node depth).
  gt_tree <- read.tree(paste0("./gt/tree_", idx, ".nwk"))
  gt_len <- sum(gt_tree$edge.length)
  gt_h <- max(node.depth.edgelength(gt_tree))

  est_cols <- c("q5", "median", "q95")

  root_rows <- tmp$variable == root_n
  tmp[root_rows, est_cols] <- (tmp[root_rows, est_cols] - gt_h) / gt_h
  tmp[root_rows, "variable"] <- height_n

  len_rows <- tmp$variable == "tree_length"
  tmp[len_rows, est_cols] <- (tmp[len_rows, est_cols] - gt_len) / gt_len
  tmp[len_rows, "variable"] <- len_n

  tmp
}
# Stack the summaries of every (clock, phi scenario, run) analysis.
# expand.grid() varies its first factor fastest, so analyses are visited by
# clock, then phi scenario, then run.
dfs <- local({
  runs <- expand.grid(
    run_idx = seq_len(set_size),
    phi_idx = seq_len(phi_sets),
    clock = seq_len(mut_sets),
    KEEP.OUT.ATTRS = FALSE
  )
  do.call(rbind, Map(f, runs$clock, runs$phi_idx, runs$run_idx))
})
Plot Posterior Summaries Against Ground Truth
# Variables that get a summary plot; also fixes the factor level order.
vnames <- c("phi", "alpha", "nu", "omega", "mu", height_n, len_n)

# Keep only the rows of `df` whose `variable` column equals `v`.
subset_v <- function(df, v) {
  df[df$variable == v, ]
}

# Replace the phi scenario index (1-4) in `phi_idx` by a string label.
relab <- function(df) {
  mutate(
    df,
    phi_idx = recode(phi_idx, "1" = "0.0", "2" = "0.5", "3" = "0.75", "4" = "1.0")
  )
}
# One row per (clock, phi scenario, run), ordered by clock, phi, run.
conv_df <- expand.grid(
  run_idx = seq_len(set_size),
  phi_idx = seq_len(phi_sets),
  clock = seq_len(mut_sets),
  KEEP.OUT.ATTRS = FALSE
)[, c("clock", "phi_idx", "run_idx")]

# A run counts as converged when every summarised variable has bulk ESS above
# 200 and rhat below 1.05.
conv_df$conv <- mapply(
  function(cl, ph, rn) {
    run_rows <- dfs[which((dfs$clock == cl) & (dfs$phi_idx == ph) & (dfs$run_idx == rn)), ]
    all(run_rows$ess_bulk > 200) && all(run_rows$rhat < 1.05)
  },
  conv_df$clock, conv_df$phi_idx, conv_df$run_idx
)

# Keep only the runs that failed the check; the plots mark these in red.
conv_df <- conv_df[!conv_df$conv, ]
conv_df <- relab(conv_df)
# Plot the posterior median and q5-q95 range of one variable for every run,
# faceted by phi scenario (rows) and clock (columns).
#
# df:    long table of posterior summaries with columns variable, median, q5,
#        q95, run_idx, phi_idx and clock (as built in `dfs`).
# vname: one of `vnames`; the variable to plot.
#
# The red line is the ground-truth value for each run (0 for the relative
# height/length errors). Red crosses on the bottom edge mark the runs listed in
# the global `conv_df`, i.e. the runs that failed the convergence check.
#
# Returns a ggplot object.
sum_plt <- function(df, vname) {
  allDf <- df[df$variable %in% vnames, ]

  # Ground-truth table: one row per (clock, phi scenario, run).
  gt_df <- expand.grid(
    run_idx = seq_len(set_size),
    phi_idx = seq_len(phi_sets),
    clock = seq_len(mut_sets),
    KEEP.OUT.ATTRS = FALSE
  )[, c("clock", "phi_idx", "run_idx")]
  gt_df$phi <- phis[gt_df$phi_idx]
  gt_df$alpha <- alphas[gt_df$run_idx]
  gt_df$mu <- mus[gt_df$clock]
  gt_df$omega <- omegas[gt_df$clock]
  gt_df$nu <- nu
  # Relative errors are measured against the truth, so their target is zero.
  gt_df[[height_n]] <- 0.0
  gt_df[[len_n]] <- 0.0

  gt_df <- melt(gt_df, measure.vars = vnames)
  gt_df$variable <- factor(gt_df$variable, levels = vnames)
  allDf$variable <- factor(allDf$variable, levels = vnames)

  gt_df <- relab(subset_v(gt_df, vname))
  allDf <- relab(subset_v(allDf, vname))

  ggplot(allDf, aes(x = run_idx, ymin = q5, y = median, ymax = q95)) +
    geom_errorbar(width = 0.1, alpha = 0.4) +
    geom_point(size = 0.5, alpha = 0.8) +
    # y = -Inf pins the markers to the bottom edge of each panel.
    geom_point(
      data = conv_df, aes(x = run_idx, y = -Inf),
      color = "red", size = 2.0, shape = 4, inherit.aes = FALSE
    ) +
    geom_line(
      data = gt_df, aes(x = run_idx, y = value),
      color = "red", inherit.aes = FALSE
    ) +
    facet_grid(
      rows = vars(phi_idx), cols = vars(clock),
      labeller = label_bquote("Phi Scenario" == .(phi_idx), "Clock" == .(clock))
    ) +
    theme_minimal() +
    labs(x = "Run") +
    # clip = "off" keeps the edge markers from being cut at the panel border.
    coord_cartesian(clip = "off") +
    ggtitle(vname) +
    theme(
      axis.text.x = element_text(size = rel(0.7), angle = 45, hjust = 1),
      plot.margin = margin(0, 0, 0, 0, "cm"),
      panel.grid.major = element_blank(),
      axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
      plot.title = element_text(hjust = 0.5, size = rel(1.0))
    )
}
# For each summarised variable: build the summary plot, write it to an 8 x 8
# PDF under ../manuscript_figs/, then print it so it also appears inline.
# The `#>` lines are console output recorded when the script was rendered.

# phi
p <- sum_plt(dfs, "phi")
#> Warning: The `size` argument of `element_line()` is deprecated as of ggplot2 3.4.0.
#> ℹ Please use the `linewidth` argument instead.
#> This warning is displayed once every 8 hours.
#> Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
#> generated.
pdf("../manuscript_figs/kmb_phi.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen
#> 2
p
# alpha
p <- sum_plt(dfs, "alpha")
pdf("../manuscript_figs/kmb_alpha.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen
#> 2
p
# nu
p <- sum_plt(dfs, "nu")
pdf("../manuscript_figs/kmb_nu.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen
#> 2
p
# mu
p <- sum_plt(dfs, "mu")
pdf("../manuscript_figs/kmb_mu.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen
#> 2
p
# omega
p <- sum_plt(dfs, "omega")
pdf("../manuscript_figs/kmb_omega.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen
#> 2
p
# Relative error of the root height; the file name is derived from height_n.
p <- sum_plt(dfs, height_n)
pdf(paste0("../manuscript_figs/kmb_", height_n, ".pdf"),8,8)
plot(p)
dev.off()
#> quartz_off_screen
#> 2
p
# Relative error of the tree length; the file name is derived from len_n.
p <- sum_plt(dfs, len_n)
pdf(paste0("../manuscript_figs/kmb_", len_n, ".pdf"),8,8)
plot(p)
dev.off()
#> quartz_off_screen
#> 2
p
# Summarise the posterior branch count of one analysis against the
# ground-truth tree.
#
# clock, phi_idx, run_idx: scenario and run indices, as for `f`.
#
# Returns a named numeric vector:
#   q5, median, q95 - quantiles of (per-draw branch count - true branch count).
#                     NOTE: despite the names these are the 2.5%, 50% and
#                     97.5% quantiles; the names are kept because the plots
#                     refer to them.
#   p_mm            - fraction of draws whose branch count is below
#                     2 * n_tips - 2.
#   run_idx, phi_idx, clock - the inputs, echoed back.
h <- function(clock, phi_idx, run_idx) {
  idx <- (phi_idx - 1L) * set_size + run_idx
  res <- readRDS(paste0("./mut", clock, "/ana_out/res_", idx, ".rds"))

  # Per draw, sum (1 - q_i) over the q_* columns; the root has a q but no
  # branch, hence the final -1.
  q_cols <- which(colnames(res$draws) %in% paste0("q_", 1:(2 * n_tips - 1)))
  bcounts <- apply(res$draws, 1, function(draw) sum(1 - draw[q_cols])) - 1

  # True branch count: edges of the ground-truth tree after di2multi().
  gt_tree <- di2multi(read.tree(paste0("./gt/tree_", idx, ".nwk")))
  gt_br <- nrow(gt_tree$edge)

  b_ci <- quantile(bcounts - gt_br, probs = c(0.025, .5, 0.975))
  p_mm <- sum(((2 * n_tips - 2) - bcounts) > 0) / length(bcounts)

  c(
    q5 = b_ci[[1]], median = b_ci[[2]], q95 = b_ci[[3]],
    p_mm = p_mm, run_idx = run_idx, phi_idx = phi_idx, clock = clock
  )
}
# Branch-count summaries for every (clock, phi scenario, run), one row each,
# visited by clock, then phi scenario, then run.
df_bci <- local({
  runs <- expand.grid(
    run_idx = seq_len(set_size),
    phi_idx = seq_len(phi_sets),
    clock = seq_len(mut_sets),
    KEEP.OUT.ATTRS = FALSE
  )
  data.frame(do.call(rbind, Map(h, runs$clock, runs$phi_idx, runs$run_idx)))
})
# Swap the phi scenario index for its label before plotting.
df_bci <- relab(df_bci)
# Branch-count error per run: posterior median with the interval stored in
# q5/q95, faceted by phi scenario (rows) and clock (columns). The red line at 0
# is the true branch count; red crosses on the bottom edge mark the runs listed
# in conv_df (runs that failed the convergence check).
# df_bci is already relabelled where it is built, so it is used as is, and the
# conv_df marker layer is added only once.
p2 <- ggplot(df_bci, aes(x = run_idx, ymin = q5, y = median, ymax = q95)) +
  geom_errorbar(width = 0.1, alpha = 0.4) +
  geom_point(size = 0.5, alpha = 0.8) +
  geom_hline(yintercept = 0, color = "red") +
  geom_point(
    data = conv_df, aes(x = run_idx, y = -Inf),
    color = "red", size = 2.0, shape = 4, inherit.aes = FALSE
  ) +
  facet_grid(
    rows = vars(phi_idx), cols = vars(clock),
    labeller = label_bquote("Phi Scenario" == .(phi_idx), "Clock" == .(clock))
  ) +
  theme_minimal() +
  labs(x = "Run") +
  coord_cartesian(clip = "off") +
  theme(
    axis.text.x = element_text(size = rel(0.7), angle = 45, hjust = 1),
    axis.text.y = element_text(size = rel(0.9)),
    plot.margin = margin(0, 0, 0, 0, "cm"),
    panel.grid.major = element_blank(),
    axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
    plot.title = element_text(hjust = 0.5, size = rel(1.0))
  )
# Per-run fraction of posterior draws with fewer than 2 * n_tips - 2 branches
# (p_mm), on a log2 y axis, faceted by phi scenario (rows) and clock (columns).
# Red reference lines sit at the axis breaks. Red crosses mark the runs listed
# in conv_df; they are drawn at y = 0, which the log2 scale maps to -Inf, i.e.
# the bottom edge of the panel (this is one source of the "infinite values"
# warning when the plot is rendered).
p3 <- ggplot(df_bci, aes(x = run_idx, y = p_mm)) +
  geom_point(size = 0.5, alpha = 0.8) +
  geom_hline(yintercept = c(.99, .95, .90, .75, .5), color = "red") +
  geom_point(
    data = conv_df, aes(x = run_idx, y = 0),
    color = "red", size = 2.0, shape = 4, inherit.aes = FALSE
  ) +
  facet_grid(
    rows = vars(phi_idx), cols = vars(clock),
    labeller = label_bquote("Phi Scenario" == .(phi_idx), "Clock" == .(clock))
  ) +
  scale_y_continuous(transform = "log2", breaks = c(.99, .95, .90, .75, .5)) +
  coord_cartesian(clip = "off") +
  theme_minimal() +
  labs(x = "Run") +
  theme(
    axis.text.x = element_text(size = rel(0.7), angle = 45, hjust = 1),
    axis.text.y = element_text(size = rel(0.9)),
    plot.margin = margin(0, 0, 0, 0, "cm"),
    panel.grid.major = element_blank(),
    axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
    plot.title = element_text(hjust = 0.5, size = rel(1.0))
  )
Plot relative branch counts
# Write the branch-count plot (p2) to an 8 x 8 PDF, then print it inline.
pdf("../manuscript_figs/kmb_bcount.pdf",8,8)
plot(p2)
dev.off()
#> quartz_off_screen
#> 2
p2
Plot Posterior Probabilities of the Tree Containing Multiple Mergers
# Write the p_mm plot (p3) to an 8 x 8 PDF, then print it inline. Each render
# repeats the recorded log2 warning about infinite values.
pdf("../manuscript_figs/kmb_pmm.pdf",8,8)
plot(p3)
#> Warning in scale_y_continuous(trans = "log2", breaks = c(0.99, 0.95, 0.9, :
#> log-2 transformation introduced infinite values.
dev.off()
#> quartz_off_screen
#> 2
p3
#> Warning in scale_y_continuous(trans = "log2", breaks = c(0.99, 0.95, 0.9, :
#> log-2 transformation introduced infinite values.
# Mean root-to-tip depth of the clock-scaled tree of one run.
#
# clock, phi_idx, run_idx: scenario and run indices; phi_idx is only used to
# compute the file index.
#
# Returns a named numeric vector c(exp_subs, run_idx, clock), where exp_subs
# is the mean depth of the first n_tips nodes (ape numbers the tips first).
count_subs <- function(clock, phi_idx, run_idx) {
  idx <- (phi_idx - 1L) * set_size + run_idx
  clock_tree <- read.tree(paste0("./mut", clock, "/tree_clock_", idx, ".nwk"))
  tip_depths <- node.depth.edgelength(clock_tree)[1:n_tips]
  c(exp_subs = mean(tip_depths), run_idx = run_idx, clock = clock)
}
# Mean tip depth of every clock-scaled tree, one row per (clock, phi, run),
# visited by clock, then phi scenario, then run.
subs_df <- local({
  runs <- expand.grid(
    run_idx = seq_len(set_size),
    phi_idx = seq_len(phi_sets),
    clock = seq_len(mut_sets),
    KEEP.OUT.ATTRS = FALSE
  )
  data.frame(
    do.call(rbind, Map(count_subs, runs$clock, runs$phi_idx, runs$run_idx))
  )
})
# Violin plot of the mean tip depth per clock scenario. The values are divided
# by 1e4 to express them per site (presumably the alignment length -- confirm
# against the simulation settings).
p4 <- ggplot(subs_df, aes(factor(clock), exp_subs / 1e4L)) +
  geom_violin() +
  theme_minimal() +
  labs(y = "Expected Substitutions Per Site", x = "Clock") +
  theme(
    axis.text.x = element_text(size = rel(1.0), angle = 45, hjust = 1),
    plot.margin = margin(0, 0, 0, 0, "cm"),
    panel.grid.major = element_blank(),
    axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
    plot.title = element_text(hjust = 0.5, size = rel(1.0))
  )
Plot Expected Number of Substitutions for Each Clock
# Write the expected-substitutions plot (p4) to an 8 x 8 PDF, then print it
# inline.
pdf("../manuscript_figs/kmb_exp_subs.pdf",8,8)
plot(p4)
dev.off()
#> quartz_off_screen
#> 2
p4
Plot comparison against TreeTime and LSD2
# Compare our estimate with TreeTime and LSD2 for one run.
#
# clock, phi_idx, run_idx: scenario and run indices, as for `f`.
#
# Returns a 12-row data.frame with columns value, variable, method, run_idx,
# phi_idx and clock. `variable` is one of rel_branch_count, rel_length,
# rel_height (each (estimate - truth) / truth) or dist (KF.dist to the
# ground-truth tree); `method` is "ours", "treetime" or "lsd2". For "ours" the
# relative errors use posterior medians and dist is the mean distance over the
# sampled time trees.
#
# If the TreeTime output file is missing, a message is printed and all four
# TreeTime values are NA.
h2 <- function(clock, phi_idx, run_idx) {
  idx <- (phi_idx - 1L) * set_size + run_idx
  res_mmc <- readRDS(paste0("./mut", clock, "/ana_out/res_", idx, ".rds"))
  res_lsd <- readRDS(paste0("./mut", clock, "/ana_lsd/res_", idx, ".rds"))
  mmc_sums <- res_mmc$summaries
  lsd_tree <- res_lsd$dateNexusTree@phylo

  # Ground truth.
  gt_t <- di2multi(read.tree(paste0("./gt/tree_", idx, ".nwk")))
  gt_height <- max(node.depth.edgelength(gt_t))
  gt_length <- sum(gt_t$edge.length)
  gt_bcount <- nrow(gt_t$edge)

  # Our method: per-draw branch count is sum(1 - q_i) - 1, since the root has
  # a q but no branch.
  q_idx <- which(colnames(res_mmc$draws) %in% paste0("q_", 1:(2 * n_tips - 1)))
  draw_bcounts <- apply(res_mmc$draws, 1, function(x) sum(1 - x[q_idx])) - 1
  bcounts_mmc <- (median(draw_bcounts) - gt_bcount) / gt_bcount

  # TreeTime: its output only exists if the run completed.
  tt_res <- paste0("./mut", clock, "/ana_treetime/tree_", idx, "/timetree.nexus")
  tt_done <- file.exists(tt_res)
  tt_tree <- NA
  bcounts_tt <- NA
  tl_tt <- NA
  height_tt <- NA
  if (tt_done) {
    tt_tree <- read.nexus(tt_res)
    tl_tt <- (sum(tt_tree$edge.length) - gt_length) / gt_length
    height_tt <- (max(node.depth.edgelength(tt_tree)) - gt_height) / gt_height
    bcounts_tt <- (nrow(di2multi(tt_tree)$edge) - gt_bcount) / gt_bcount
  } else {
    print(paste("Treetime run:",run_idx, "clock:",clock, "failed to complete." ))
  }

  # LSD2.
  bcounts_lsd <- (nrow(di2multi(lsd_tree)$edge) - gt_bcount) / gt_bcount
  tl_lsd <- (sum(lsd_tree$edge.length) - gt_length) / gt_length
  height_lsd <- (max(node.depth.edgelength(lsd_tree)) - gt_height) / gt_height

  bcounts <- c(bcounts_mmc, bcounts_tt, bcounts_lsd)
  tl_mmc <- unlist(mmc_sums[mmc_sums$variable == "tree_length", "median"] - gt_length) / gt_length
  tl <- c(tl_mmc, tl_tt, tl_lsd)
  height_mmc <- unlist(mmc_sums[mmc_sums$variable == root_n, "median"] - gt_height) / gt_height
  height <- c(height_mmc, height_tt, height_lsd)

  avg_dist_mmc <- mean(vapply(
    sample_timetree(res_mmc, n_samp = res_mmc$n_draws, replace = FALSE),
    function(x) KF.dist(x, gt_t, rooted = TRUE),
    numeric(1)
  ))
  # Only measure the TreeTime distance when a tree was actually read;
  # di2multi(NA) would otherwise abort the whole comparison.
  dist_tt <- if (tt_done) KF.dist(di2multi(tt_tree), gt_t, rooted = TRUE) else NA
  dist <- c(
    avg_dist_mmc, dist_tt, KF.dist(di2multi(lsd_tree), gt_t, rooted = TRUE)
  )

  out <- data.frame(
    value = c(bcounts, tl, height, dist),
    variable = c(rep("rel_branch_count", 3), rep("rel_length", 3), rep("rel_height", 3), rep("dist", 3)),
    method = rep(c("ours", "treetime", "lsd2"), 4)
  )
  out$run_idx <- run_idx
  out$phi_idx <- phi_idx
  out$clock <- clock
  out
}
# Method-comparison table for every (clock, phi scenario, run), visited by
# clock, then phi scenario, then run.
df_summs <- local({
  runs <- expand.grid(
    run_idx = seq_len(set_size),
    phi_idx = seq_len(phi_sets),
    clock = seq_len(mut_sets),
    KEEP.OUT.ATTRS = FALSE
  )
  data.frame(do.call(rbind, Map(h2, runs$clock, runs$phi_idx, runs$run_idx)))
})
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both
#> one tree is unrooted, unrooted both
# Sign-preserving power transform of `y` with exponent `lambda`:
# sign(y) * ((|y| + 1)^lambda - 1) / lambda, or sign(y) * log(|y| + 1) when
# lambda is 0.
.mod_transform <- function(y, lambda) {
  shifted <- abs(y) + 1
  if (lambda == 0) {
    return(sign(y) * (log(shifted)))
  }
  sign(y) * ((shifted ^ lambda - 1) / lambda)
}
# Inverse of .mod_transform(): maps transformed values `yt` back to the
# original scale for the same `lambda`.
.mod_inverse <- function(yt, lambda) {
  magnitude <- abs(yt)
  if (lambda == 0) {
    return((exp(magnitude) - 1) * sign(yt))
  }
  ((magnitude * lambda + 1) ^ (1 / lambda) - 1) * sign(yt)
}
# Round each break to two significant digits (so larger numbers are rounded
# more coarsely); zeros are left as zero.
prettify <- function(breaks) {
  n_digits <- 1 - floor(log10(abs(breaks)))
  # log10(0) is -Inf, so give zeros an explicit digit count.
  n_digits[breaks == 0] <- 0
  round(breaks, digits = n_digits)
}
# Build a breaks function for an axis using the .mod_transform() scale.
#
# lambda:   exponent passed to .mod_transform() / .mod_inverse().
# n:        desired number of breaks, forwarded to pretty().
# prettify: if TRUE, round the breaks with prettify().
#
# Returns a function of the axis limits `x` that picks pretty breaks on the
# transformed scale and maps them back to the data scale.
mod_breaks <- function(lambda, n = 6, prettify = TRUE) {
  function(x) {
    transformed <- .mod_transform(x, lambda)
    breaks <- .mod_inverse(pretty(transformed, n = n), lambda)
    if (prettify) {
      # The call resolves to the prettify() function, not the logical flag.
      breaks <- prettify(breaks)
    }
    breaks
  }
}
# Summary for stat_summary(fun.data = ...): the median of `x` as y and its
# 2.5% / 97.5% quantiles as ymin / ymax, in a one-row data.frame.
sum_func <- function(x) {
  qs <- quantile(x, probs = c(.025, .5, .975), names = FALSE)
  data.frame(y = qs[2], ymin = qs[1], ymax = qs[3])
}
# Violin plot comparing the methods on one error measure.
#
# df: long comparison table with columns value, variable, method, phi_idx and
#     clock (as built in `df_summs`).
# vn: value of `variable` to plot.
#
# Each violin is overlaid with the median and 2.5%-97.5% range from
# sum_func(); the red line at 0 marks zero error. Panels are faceted by clock
# (columns) and phi scenario (rows) with free y scales, and the y axis uses a
# modulus transform with lambda = 0.
#
# Returns a ggplot object.
foo <- function(df, vn) {
  df %>%
    filter(variable == vn) %>%
    ggplot(aes(x = method, y = value, fill = method)) +
    geom_violin(aes(color = method)) +
    # No `width` here: the pointrange geom ignores it and only warns.
    stat_summary(
      fun.data = sum_func, geom = "pointrange",
      color = "gray35", fill = "gray35"
    ) +
    geom_hline(yintercept = 0, color = "red") +
    facet_grid(
      cols = vars(clock), rows = vars(phi_idx), scales = "free_y",
      labeller = label_bquote("Phi Scenario" == .(phi_idx), "Clock" == .(clock))
    ) +
    theme_minimal() +
    labs(x = "Method", y = "Relative Error") +
    ggtitle(vn) +
    scale_y_continuous(
      transform = scales::transform_modulus(0),
      breaks = mod_breaks(lambda = 0, prettify = TRUE)
    ) +
    theme(
      axis.text.x = element_text(size = rel(0.7), angle = 45, hjust = 1),
      plot.margin = margin(0, 0, 0, 0, "cm"),
      panel.grid.major = element_blank(),
      axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
      plot.title = element_text(hjust = 0.5, size = rel(1.0))
    )
}
# For each relative-error measure: build the method-comparison plot, write it
# to an 8 x 8 PDF under ../manuscript_figs/, then print it inline.
# The `#>` lines are console output recorded when the script was rendered.

# Relative branch-count error.
p <- foo(df_summs, "rel_branch_count")
#> Warning in stat_summary(fun.data = sum_func, geom = "pointrange", width = 0.05,
#> : Ignoring unknown parameters: `width`
pdf("../manuscript_figs/kmb_compare_bcount.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen
#> 2
p
# Relative tree-length error.
p <- foo(df_summs, "rel_length")
#> Warning in stat_summary(fun.data = sum_func, geom = "pointrange", width = 0.05,
#> : Ignoring unknown parameters: `width`
pdf("../manuscript_figs/kmb_compare_length.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen
#> 2
p
# Relative tree-height error.
p <- foo(df_summs, "rel_height")
#> Warning in stat_summary(fun.data = sum_func, geom = "pointrange", width = 0.05,
#> : Ignoring unknown parameters: `width`
pdf("../manuscript_figs/kmb_compare_height.pdf",8,8)
plot(p)
dev.off()
#> quartz_off_screen
#> 2
p
# Violin plot of the distance to the ground-truth tree per method, overlaid
# with the median and 2.5%-97.5% range from sum_func(). Panels are faceted by
# clock (columns) and phi scenario (rows) with free y scales; the y axis uses
# a modulus transform with lambda = 0.
p_dists <- df_summs %>%
  filter(variable == "dist") %>%
  ggplot(aes(x = method, y = value, fill = method)) +
  geom_violin(aes(color = method)) +
  # No `width` here: the pointrange geom ignores it and only warns.
  stat_summary(
    fun.data = sum_func, geom = "pointrange",
    color = "gray35", fill = "gray35"
  ) +
  facet_grid(
    cols = vars(clock), rows = vars(phi_idx), scales = "free_y",
    labeller = label_bquote("Phi Scenario" == .(phi_idx), "Clock" == .(clock))
  ) +
  theme_minimal() +
  labs(x = "Method", y = "Branch Score Distance") +
  ggtitle("dist") +
  scale_y_continuous(
    transform = scales::transform_modulus(0),
    breaks = mod_breaks(lambda = 0, prettify = TRUE)
  ) +
  theme(
    axis.text.x = element_text(size = rel(0.7), angle = 45, hjust = 1),
    plot.margin = margin(0, 0, 0, 0, "cm"),
    panel.grid.major = element_blank(),
    axis.line = element_line(linewidth = rel(0.2), colour = "grey80"),
    plot.title = element_text(hjust = 0.5, size = rel(1.0))
  )
#> Warning in stat_summary(fun.data = sum_func, geom = "pointrange", width = 0.05,
#> : Ignoring unknown parameters: `width`
# Write the distance comparison plot (p_dists) to an 8 x 8 PDF, then print it
# inline.
pdf("../manuscript_figs/kmb_dists.pdf",8,8)
plot(p_dists)
dev.off()
#> quartz_off_screen
#> 2
p_dists
Print bias, RMSE, and 5%-95% quantile spread for each method and variable
# Per method and variable: mean error (bias), root-mean-square error, the 5%
# and 95% quantiles of the error, and the width of that 90% range.
df_summs %>%
  group_by(method, variable) %>%
  summarise(
    bias = mean(value),
    rmse = sqrt(mean(value^2)),
    q5 = quantile(value, 0.05, names = FALSE),
    q95 = quantile(value, 0.95, names = FALSE),
    iqr90 = q95 - q5
  ) %>%
  print(n = 100)
#> `summarise()` has grouped output by 'method'. You can override using the
#> `.groups` argument.
#> # A tibble: 12 × 7
#> # Groups: method [3]
#> method variable bias rmse q5 q95 iqr90
#> <chr> <chr> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 lsd2 dist 20.3 25.5 6.62 54.3 47.7
#> 2 lsd2 rel_branch_count -0.217 0.232 -0.344 -0.0575 0.287
#> 3 lsd2 rel_height 0.511 0.714 0.00617 1.61 1.60
#> 4 lsd2 rel_length 0.438 0.637 -0.0139 1.53 1.54
#> 5 ours dist 10.6 10.9 6.98 15.6 8.65
#> 6 ours rel_branch_count 0.0160 0.0298 -0.0151 0.0621 0.0772
#> 7 ours rel_height 0.000653 0.125 -0.181 0.220 0.401
#> 8 ours rel_length -0.00703 0.0776 -0.129 0.122 0.251
#> 9 treetime dist 3884. 85016. 6.77 60.9 54.1
#> 10 treetime rel_branch_count -0.113 0.142 -0.209 0.0652 0.274
#> 11 treetime rel_height 112. 2562. -0.430 1.65 2.08
#> 12 treetime rel_length 117. 2751. -0.335 1.25 1.59